
%% Setup
% Container for the per-subject results; the subject loop below stores its
% output in OUT.math and OUT.reading.
OUT = struct();

% Run relative to the current folder and put the 'matlab functions' folder on
% the path (presumably where project helpers such as latextbl live - confirm).
cd('.')
addpath('matlab functions')

% Analysis data set; the subject loop indexes its columns by coefficient name.
D = readtable('NC Data/SCC_4to5.csv');

% Variance decomposition of teacher value-added, run once per subject.
%
% For each subject the loop loads output/main/vam_<subject>_matching.mat and
% reads from it:
%   tchcoef                names of the teacher coefficients; every name must
%                          also be a column of D
%   VAmomentEstimates.cov  covariance matrix of those coefficients
%   VAdistEstimates        mixture description with fields mean, cov and pr
% and stores in OUT.<subject>:
%   total_var    table with rows total/level/match and columns var_<subject>,
%                pct_<subject> (percent of the total)
%   partial_var  table with one row per non-'yr' variable (var_<subject>)
%   within       10/25/50/75/90th percentiles of sqrt(within_var)

% Monte Carlo settings. The seed is fixed so that the simulated percentiles and
% the consistency check below give the same answer on every run.
mc_seed  = 1;
mc_draws = 3e6;

for sbj = {'math','reading'}
    subj = sbj{1};
    vamdl = load(sprintf('output/main/vam_%s_matching.mat',subj));

    tchcoef = vamdl.tchcoef;
    % Coefficients whose name starts with 'yr' form the "level" part; all
    % remaining coefficients form the "match" part.
    lvl = startsWith(tchcoef,'yr');

    coefCOV = vamdl.VAmomentEstimates.cov;
    Z_all  = D{:,tchcoef};      % extracted once, reused for both moments
    Z_cov  = cov(Z_all);
    Z_mean = mean(Z_all);

    % For symmetric coefCOV and Z_cov the level and match pieces add up to
    %   Z_mean*coefCOV*Z_mean' + trace(coefCOV*Z_cov).
    % The mean term, the level/level block and the two (equal) cross blocks are
    % assigned to "level"; only the match/match block is assigned to "match".
    total_var_level = Z_mean*coefCOV*Z_mean' ...
        + trace(coefCOV(lvl,lvl)*Z_cov(lvl,lvl)) ...
        + 2*trace(coefCOV(lvl,~lvl)*Z_cov(~lvl,lvl));

    total_var_match = trace(coefCOV(~lvl,~lvl)*Z_cov(~lvl,~lvl));

    total_var = total_var_level + total_var_match;

    total_var = [total_var; total_var_level; total_var_match];
    OUT.(subj).total_var = array2table([total_var 100*total_var./total_var(1)], ...
        'RowNames',{'total','level','match'}, ...
        'VariableNames',{['var_' subj] , ['pct_' subj]});

    % Map every coefficient to the underlying variables whose product forms its
    % regressor: names are split on 'X', and the squared lagged score is
    % represented as the lagged score multiplied by itself.
    lag_name = sprintf('lag_%sscore',subj);
    tchcoef_split = cellfun(@(x) strsplit(x,'X'),tchcoef,'unif',0);
    is_sq = strcmp(tchcoef,[lag_name '_sq']);
    % Parenthesis assignment with a scalar cell is valid for any number of
    % matches, including none.
    tchcoef_split(is_sq) = {repmat({lag_name},[1 2])};
    tchvars = unique(cat(2,tchcoef_split{:}));
    tchvars_match = tchvars(~startsWith(tchvars,'yr'));

    partial_var_match = struct();
    for v = tchvars_match
        % Apply mean() to every variable except v{1}. Presumably, with the empty
        % grouping argument, this holds them at their full-sample mean so that
        % only v{1} varies - confirm against the grouptransform documentation.
        Z = grouptransform(D(:,tchvars),[],@(x) mean(x),setdiff(tchvars,v));
        % Rebuild one regressor column per coefficient from the transformed
        % variables.
        Z = cellfun(@(x) prod(Z{:,x},2),tchcoef_split,'unif',0);
        Z = cat(2,Z{:});
        % Kept separate from Z_cov so the full-data covariance stays available
        % for the simulation check further down.
        Z_cov_v = cov(Z);
        partial_var_match.(v{1}) = trace(coefCOV(~lvl,~lvl)*Z_cov_v(~lvl,~lvl));
    end

    % Reshape the struct of scalars into a one-column table keyed by variable
    % name, using the subject-neutral key 'lag_score' for the lagged score.
    partial_var_match = struct2table(partial_var_match);
    partial_var_match = renamevars(partial_var_match,lag_name,'lag_score');
    partial_var_match = rows2vars(partial_var_match);
    partial_var_match.Properties.RowNames = partial_var_match.OriginalVariableNames;
    partial_var_match = removevars(partial_var_match,'OriginalVariableNames');
    partial_var_match.Properties.VariableNames = {['var_' subj]};

    OUT.(subj).partial_var = partial_var_match;

    % Simulate coefficient vectors from the estimated mixture after removing the
    % pr-weighted average of the component means.
    vadist = vamdl.VAdistEstimates;
    vadist.mean = vadist.mean - vadist.mean*vadist.pr';
    gm = gmdistribution(vadist.mean',vadist.cov,vadist.pr);

    % Draw with a fixed seed, then restore the generator state that was active
    % before, so the rest of the session is unaffected.
    prev_rng = rng(mc_seed,'twister');
    va_draw = random(gm,mc_draws);
    rng(prev_rng);

    % Quadratic form b_m * Z_cov(match,match) * b_m' for every simulated draw b.
    draw_match = va_draw(:,~lvl);
    within_var = sum((draw_match*Z_cov(~lvl,~lvl)).*draw_match,2);

    % The simulated average must reproduce the analytic match variance.
    mc_gap = abs(total_var_match - mean(within_var));
    assert(mc_gap<2e-5, ...
        'Simulated match variance for %s differs from the analytic value by %g.', ...
        subj,mc_gap)

    OUT.(subj).within = prctile(sqrt(within_var),[10 25 50 75 90]);

end

%% Table: total / level / match decomposition for both subjects
% Combine the math and reading results on their shared row names.
total_var = innerjoin(OUT.math.total_var,OUT.reading.total_var,'Keys','Row');

% Add a display-label column: blank everywhere first, then filled per row key.
total_var(:,'name') = deal({''});
row_labels = { ...
    'total','Total'; ...
    'level','Level'; ...
    'match','Matching'};
for k = 1:size(row_labels,1)
    total_var(row_labels{k,1},'name') = row_labels(k,2);
end

% Fix the row order and put the label column in front of the numbers.
col_order = {'name','var_math','pct_math','var_reading','pct_reading'};
total_var = total_var(row_labels(:,1)',col_order);

% latextbl is a project helper not defined in this file; presumably it writes
% the table to the given .tex path using one format per column - confirm.
col_fmt = {'%s','%0.4f','%0.1f','%0.4f','%0.1f'};
latextbl(total_var,'tables and figures/tables/va_total_decomp.tex',col_fmt);

%% Table: match variance by student characteristic for both subjects
% Combine the math and reading results on their shared row names.
partial_var = innerjoin(OUT.math.partial_var,OUT.reading.partial_var,'Keys','Row');

% Add a display-label column: blank everywhere first, then filled per row key.
partial_var(:,'name') = deal({''});
row_labels = { ...
    'female',    'Female'; ...
    'lag_score', 'Lag Score'; ...
    'lep',       'LEP'; ...
    'poverty',   'FRL'; ...
    'race_black','Black'};
for k = 1:size(row_labels,1)
    partial_var(row_labels{k,1},'name') = row_labels(k,2);
end
% NOTE(review): strcat with a single cell-array argument looks like a no-op;
% kept so that behaviour is unchanged.
partial_var.name = strcat(partial_var.name);

% Percent shares are taken over every row of the joined table, i.e. before the
% row selection below.
for s = {'math','reading'}
    col = partial_var.(['var_' s{1}]);
    partial_var.(['pct_' s{1}]) = 100*col./sum(col);
end

% Keep the reported rows in presentation order, label column first.
row_order = {'lag_score','female','race_black','poverty','lep'};
col_order = {'name','var_math','pct_math','var_reading','pct_reading'};
partial_var = partial_var(row_order,col_order);

% latextbl is a project helper not defined in this file; presumably it writes
% the table to the given .tex path using one format per column - confirm.
col_fmt = {'%s','%0.4f','%0.1f','%0.4f','%0.1f'};
latextbl(partial_var,'tables and figures/tables/va_match_decomp.tex',col_fmt);

